#!/usr/bin/env python
# -*-coding=utf-8-*-

import re
import numpy as np

def train_data_format(paths):
    """
    Read one text file per class and return tokenised documents with labels.

    Every line of a file is treated as one document and split on whitespace.
    The label of a document is the position of its file inside ``paths``.

    :param paths: iterable of file paths, one file per class
    :return: tuple ``(all_text, all_label)`` where ``all_text`` is a list of
        token lists and ``all_label`` is the parallel list of int labels
    """
    all_text = []
    all_label = []

    for label_index, path in enumerate(paths):
        # The context manager closes the handle even if reading fails.
        with open(path) as f:
            docs = [doc.strip("\n").split() for doc in f]
        all_text.extend(docs)
        all_label.extend([label_index] * len(docs))

    return all_text, all_label


def save_split_fastext(paths, train_data_path):
    """
    Write the documents found in ``paths`` as one labelled training file.

    Each output line has the form ``__label__<index> , <tokens>`` where
    ``<index>`` is the position of the source file inside ``paths`` and
    ``<tokens>`` are the document's tokens joined by single spaces.

    :param paths: iterable of input file paths, one file per class
    :param train_data_path: path of the file to create (overwritten)
    :return: None
    """

    label_prefix = "__label__"
    all_text, all_label = train_data_format(paths)

    # Closing the file through ``with`` guarantees the buffered lines are
    # flushed to disk before the function returns.
    with open(train_data_path, 'w') as outf:
        for label, text in zip(all_label, all_text):
            tagged_label = label_prefix + str(label)
            outf.write(" , ".join([tagged_label, " ".join(text)]) + "\n")

def label_to_int(label):
    """
    Turn a label token such as ``__label__3`` into the integer ``3``.

    :param label: label string, with or without the ``__label__`` prefix
    :return: the label as an int
    """
    return int(label.replace("__label__", ""))


def load_fastext_train_data(path):
    """
    Load a labelled training file whose lines look like ``<label> , <text>``.

    The part before the first ``" , "`` is parsed with ``label_to_int``;
    everything after it is kept verbatim as the document text (including any
    further ``" , "`` sequences).

    :param path: path of the labelled training file
    :return: tuple ``(train_x, train_y)`` of text strings and int labels
    :raises ValueError: if the label part of a line is not an integer
        (for example on a blank line)
    """
    train_x = []
    train_y = []
    # The context manager closes the handle even if a line fails to parse.
    with open(path) as f:
        for doc in f:
            # partition splits on the first separator only, so the text is
            # returned exactly as it was written.
            label, _, text = doc.strip("\n").partition(" , ")
            train_y.append(label_to_int(label))
            train_x.append(text)

    return train_x, train_y

def label_to_array(labels, num_classes):
    """
    Build a one-hot matrix from a list of integer class labels.

    :param labels:  [int], one class index per sample
    :param num_classes: number of columns of the resulting matrix
    :return: numpy array of shape ``(len(labels), num_classes)`` filled with
        zeros except for a 1 at ``[i, labels[i]]`` in every row ``i``
    """
    one_hot = np.zeros((len(labels), num_classes))
    for row, class_index in enumerate(labels):
        one_hot[row, class_index] = 1
    return one_hot

def fastext_to_word2vec(in_path, out_path):
    """
    Strip the label column from a labelled file, leaving only the text.

    Every input line is expected to look like ``<label> , <text>``. The text
    after the first ``" , "`` is written unchanged to ``out_path``, one
    document per line. A line without the separator produces an empty
    output line.

    :param in_path: path of the labelled input file
    :param out_path: path of the text-only file to create (overwritten)
    :return: None
    """
    # Both handles are closed (and the output flushed) when the block exits.
    with open(in_path) as f, open(out_path, 'w') as outf:
        for line in f:
            # Split on the first separator only: joining the pieces with ""
            # would delete later " , " sequences and glue the surrounding
            # words together.
            _, _, content = line.strip("\n").partition(" , ")
            outf.write(content + "\n")




if __name__ == '__main__':
    # Example (disabled): build the labelled training file from the
    # rt-polarity corpus and load it back.
    # paths = ["../data/rt-polaritydata/rt-polarity.neg", "../data/rt-polaritydata/rt-polarity.pos"]
    # save_split_fastext(paths, "../data/rt-polaritydata/train.txt")
    # train_x, train_y= load_fastext_train_data("../data/rt-polaritydata/train.txt")
    # print (train_x[0])
    # print(train_y[0])

    # Produce the label-free corpus from the labelled file.
    labelled_path = "../data/train.txt"
    corpus_path = "../data/word2vec.txt"
    fastext_to_word2vec(labelled_path, corpus_path)


